# -*- coding:utf-8 -*-

import os
import importlib
import tensorflow as tf
from utils.io_utils import convert_abspath
from config.glob.global_pool import global_pool

"""
predict部分封装
"""


def load_model(sess):
    """
    Restore variables into ``sess`` from the checkpoint configured in
    ``global_pool.config.load_model_dir``.

    The configured value is treated as a checkpoint prefix: the function
    requires ``<prefix>.index`` to exist and passes the prefix to
    ``Saver.restore``.

    :param sess: TensorFlow session whose graph holds the variables to restore.
    :raises IOError: if ``<prefix>.index`` does not exist.
    :return: None
    """
    load_model_dir = convert_abspath(global_pool.config.load_model_dir)  # absolute checkpoint prefix
    # Validate the path before building the Saver, so a bad path does not leave
    # extra save/restore ops in the graph.
    if not os.path.exists(load_model_dir + '.index'):
        raise IOError('模型路径不存在')
    # tf.train.Saver() binds to the current default graph; build it inside the
    # session's own graph so it captures the variables that `sess` actually runs.
    with sess.graph.as_default():
        saver = tf.train.Saver()
    saver.restore(sess, load_model_dir)


def set_predict_in_model(model):
    """
    Restore the model's weights and attach the configured prediction function.

    The module named by ``global_pool.config.predict_mod`` is imported, its
    ``predict`` attribute is stored on ``model`` as ``model.predict``, and the
    checkpoint is restored into ``model.sess`` via ``load_model``.

    :param model: object with a ``sess`` attribute; it is mutated in place.
    :return: the same ``model`` object.
    :raises ImportError: if the configured predict module cannot be imported.
    :raises AttributeError: if that module has no ``predict`` attribute.
    :raises IOError: propagated from ``load_model`` when the checkpoint is missing.
    """
    module = importlib.import_module(global_pool.config.predict_mod)
    # Resolve `predict` before restoring, so a misconfigured module fails fast
    # instead of after a (potentially slow) checkpoint restore.
    predict_fn = module.predict
    load_model(model.sess)
    # NOTE(review): stored as a plain instance attribute, so it is NOT bound to
    # `model` (no implicit self) — confirm callers pass whatever predict expects.
    model.predict = predict_fn
    return model


